{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compression insights: interpreting the results\n",
    "\n",
    "After pruning a model, you may want to examine the statistics of the pruned model.  This notebook tries to provide some tools to help you.\n",
    "\n",
    "We compare a dense ResNet20-Cifar model with a model that was trained using SSL [(Structured Sparsity Learning)](#Wen-et-al-2016), with layer-wise regualrization to remove 5 layers.\n",
    "\n",
    "## Table of Contents\n",
    "\n",
    "1. [Load dense and sparse models of your network](#Load-dense-and-sparse-models-of-your-network)\n",
    "2. [Examine some weights tensor shapes sizes and statistics of the sparse model](#Examine-some-weights-tensor-shapes-sizes-and-statistics-of-the-sparse-model)<br>\n",
    "    2.1. [View the element-wise sparsity of each layer](#View-the-element-wise-sparsity-of-each-layer)<br>\n",
    "    2.2. [Compare the sparsity of different weight tensor sub-structures](#Compare-the-sparsity-of-different-weight-tensor-sub-structures)\n",
    "3. [Compare the distributions of the weight tensors in the sparse and dense models](#Compare-the-distributions-of-the-weight-tensors-in-the-sparse-and-dense-models)\n",
    "4. [Visualize kernel sparsity](#Visualize-kernel-sparsity)<br>\n",
    "    4.1. [Fully-connected layers](#Fully-connected-layers)<br>\n",
    "    4.2. [Convolution layers](#Convolution-layers)<br>\n",
    "5. [Visualize channel-wise sparsity](#Visualize-channel-wise-sparsity)\n",
    "6. [Kernel L1-norm distributions](#Kernel-L1-norm-distributions)\n",
    "7. [Performance metrics](#Performance-metrics)<br>\n",
    "    7.1. [Compare sparse model compute vs. dense model compute](#Compare-sparse-model-compute-vs.-dense-model-compute)<br>\n",
    "    7.2. [Now compare dense ResNet20-Cifar to ResNet20-Cifar which has 5 layers removed using SSL](#Now-compare-dense-ResNet20-Cifar-to-ResNet20-Cifar-which-has-5-layers-removed-using-SSL)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib \n",
    "\n",
    "# Load some common jupyter code\n",
    "%run distiller_jupyter_helpers.ipynb\n",
    "import ipywidgets as widgets\n",
    "from ipywidgets import interactive, interact, Layout\n",
    "from distiller.models import create_model\n",
    "from distiller.apputils import *\n",
    "import pandas as pd\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dense and sparse models of your network\n",
    "We find it useful to compare the dense and sparse models, and understand the differences.\n",
    "\n",
    "We believe that the accuracy results below can be improved, given that we didn't expend much effort in getting the results below.\n",
    "\n",
    "|       Model     |  Top1  |  Top5  |  Loss | MACs  | Footprint |\n",
    "| --------------- | ------ | ------ | ----- | ----- | --------- |\n",
    "| Baseline        | 91.780 | 99.710 | 0.376 | 100.0%| 100.0%    |\n",
    "| Channel removal | 91.740 | 99.700 | 0.332 | 73.4% | 91.22%    |\n",
    "| Layer removal   | 91.020 | 99.670 | 0.349 | 71.1% | 93.20%    |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose to examine some models\n",
    "resnet20_dense = create_model(False, 'cifar10', 'resnet20_cifar')\n",
    "checkpoint_file = \"../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar\" \n",
    "load_checkpoint(resnet20_dense, checkpoint_file)\n",
    "\n",
    "resnet20_layer_reg = create_model(False, 'cifar10', 'resnet20_cifar')\n",
    "checkpoint_file = \"../examples/ssl/checkpoints/checkpoint_trained_4D_regularized_5Lremoved.pth.tar\"\n",
    "load_checkpoint(resnet20_layer_reg, checkpoint_file);\n",
    "\n",
    "resnet20_channel_reg = create_model(False, 'cifar10', 'resnet20_cifar')\n",
    "checkpoint_file = \"../examples/ssl/checkpoints/checkpoint_trained_ch_regularized_dense.pth.tar\"\n",
    "load_checkpoint(resnet20_channel_reg, checkpoint_file);\n",
    "\n",
    "models_dict = {'Dense': resnet20_dense, 'Sparse': resnet20_layer_reg}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Examine some weights tensor shapes, sizes and statistics of the sparse model\n",
    "\n",
    "You can extract sparsity information from a model using ```distiller.weights_sparsity_summary```.<br>\n",
    "The columns with ```(%)``` in their name usually show the fraction of the tensor that is sparse (in percentage), but you can choose to display their density (percentage of non-zero, NNZ, elements). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def view_data(what, model_choice):\n",
    "    df_sparsity = distiller.weights_sparsity_summary(models_dict[model_choice])\n",
    "    # Remove these two columns which contains uninteresting values\n",
    "    df_sparsity = df_sparsity.drop(['Cols (%)', 'Rows (%)'], axis=1)\n",
    "    \n",
    "    if what == 'Density':\n",
    "        for granularity in ['Fine (%)', 'Ch (%)', '2D (%)', '3D (%)']:\n",
    "            df_sparsity[granularity] = 100 - df_sparsity[granularity]\n",
    "    display(df_sparsity)\n",
    "\n",
    "display_radio = widgets.RadioButtons(options=['Sparsity', 'Density'], value='Sparsity', description='Display:')\n",
    "model_dropdown = widgets.Dropdown(description='Model:', options=models_dict.keys(), value='Sparse')\n",
    "interact(view_data, what=display_radio, model_choice=model_dropdown);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### View the element-wise sparsity of each layer\n",
    "\n",
    "It is appearant that we removed 5 entire layers, all of which have only a few parameters.  If we wanted to reduce the footprint, then this is not a good choice of layers to remove.  Later on in the notebook we look at the compute metric and see the benefit of removing these layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rcParams.update({'font.size': 22})\n",
    "\n",
    "df_sparsity = distiller.weights_sparsity_summary(resnet20_layer_reg)\n",
    "df2_sparsity = df_sparsity[['NNZ (dense)', 'NNZ (sparse)']]\n",
    "ax = df2_sparsity.iloc[0:-1].plot(kind='bar', figsize=[30,10], title=\"Weights footprint: Sparse vs. Dense\\n(element-wise)\")\n",
    "ax.set_xticklabels(df_sparsity.Name, rotation=90);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare the sparsity of different weight tensor sub-structures\n",
    "\n",
    "Use the widget below to select the different structures who's sparsity you want to compare.<br>\n",
    "You can choose multiple structures, to compare between them.\n",
    "Ch (%) - this is the percentage of channels that are sparse<br>\n",
    "2D (%) - this is the percentage of kernels that are sparse<br>\n",
    "3D (%) - this is the percentage of filters that are sparse<br>\n",
    "Fine (%) - this is the percentage of elements that are sparse<br>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def view(cols, what): \n",
    "    df_sparsity = distiller.weights_sparsity_summary(resnet20_layer_reg)\n",
    "    if what == 'Density':\n",
    "        for granularity in ['Fine (%)', 'Ch (%)', '2D (%)', '3D (%)']:\n",
    "            df_sparsity[granularity] = 100 - df_sparsity[granularity]\n",
    "    \n",
    "    df2 = df_sparsity[list(cols)]\n",
    "    ax = df2.iloc[0:-1].plot(kind='bar', figsize=[30,10], title=\"Comparing structural sparsity\", grid=True)\n",
    "    ax.set_ylabel(\"% Sparsity\")\n",
    "    ax.set_xticklabels(df_sparsity.Name, rotation=90);\n",
    "\n",
    "\n",
    "items = ['Ch (%)', '2D (%)', '3D (%)', 'Fine (%)']\n",
    "cols_select = widgets.SelectMultiple(options=items, value=[items[1]])\n",
    "display_radio = widgets.RadioButtons(options=['Sparsity', 'Density'], value='Sparsity', description='Display:')\n",
    "interactive(view, cols=cols_select, what=display_radio)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare the distributions of the weight tensors in the sparse and dense models\n",
    "\n",
    "We trained the sparse model with layer-wise regularization, but we can see that this affected the distibution of some of the layers.\n",
    "\n",
    "Look at ```module.layer1.2.conv1.weight``` at an example of a weights tensor that is more \"compact\" after regularization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ipywidgets import *\n",
    "import bqplot.pyplot as bqplt\n",
    "from bqplot import Bars, Axis, Figure, LinearScale\n",
    "\n",
    "xs = LinearScale()\n",
    "ys = LinearScale()\n",
    "\n",
    "xs_dense = LinearScale()\n",
    "ys_dense = LinearScale()\n",
    "\n",
    "nbins = 100\n",
    "sparse_tensor = flatten(next (iter (resnet20_layer_reg.state_dict().values())))\n",
    "sparse_hist, sparse_edges = np.histogram(sparse_tensor, bins=nbins)\n",
    "sparse_bar = Bars(x=sparse_edges, y=[sparse_hist], scales={'x': xs, 'y': ys}, padding=0.2, type='grouped')\n",
    "\n",
    "\n",
    "dense_tensor = flatten(next (iter (resnet20_dense.state_dict().values())))\n",
    "dense_hist, dense_edges = np.histogram(dense_tensor, bins=nbins)\n",
    "dense_bar = Bars(x=dense_edges, y=[dense_hist], scales={'x': xs_dense, 'y': ys_dense}, padding=0.2, type='grouped')\n",
    "\n",
    "xax = Axis(scale=xs)\n",
    "yax = Axis(scale=ys, orientation='vertical',  grid_lines='solid')\n",
    "\n",
    "xax_dense = Axis(scale=xs_dense)\n",
    "yax_dense = Axis(scale=ys_dense, orientation='vertical', grid_lines='solid')\n",
    "\n",
    "f = Figure(marks=[sparse_bar], axes=[xax, yax], animation_duration=1000, title=\"Sparse\")\n",
    "f2 = Figure(marks=[dense_bar], axes=[xax_dense, yax_dense], animation_duration=1000, title=\"Dense\")\n",
    "\n",
    "shape =  distiller.size2str(next (iter (resnet20_layer_reg.state_dict().values())).size())\n",
    "param_info = widgets.Text(value=shape, description='shape:', disabled=True)\n",
    "\n",
    "params_names = conv_fc_param_names(resnet20_layer_reg)\n",
    "\n",
    "weights_dropdown = Dropdown(description='weights', options=params_names)\n",
    "\n",
    "def update(*args):\n",
    "    sparse_model = resnet20_layer_reg\n",
    "    dense_model = resnet20_dense\n",
    "    \n",
    "    param_name = weights_dropdown.value\n",
    "    \n",
    "    sparse_tensor = flatten(sparse_model.state_dict()[param_name])\n",
    "    sparse_hist, sparse_edges = np.histogram(sparse_tensor, bins=nbins, density=False)\n",
    "    sparse_bar.x = sparse_edges\n",
    "    sparse_bar.y = sparse_hist\n",
    "    \n",
    "    dense_tensor = flatten(dense_model.state_dict()[param_name])\n",
    "    dense_hist, dense_edges = np.histogram(dense_tensor, bins=nbins, density=False)\n",
    "    dense_bar.x = dense_edges\n",
    "    dense_bar.y = dense_hist\n",
    "    shape =  distiller.size2str(dense_model.state_dict()[param_name].size())\n",
    "    param_info.value = shape\n",
    "\n",
    "weights_dropdown.observe(update, 'value')\n",
    "\n",
    "VBox([HBox([weights_dropdown, param_info]), HBox([f, f2])])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize kernel sparsity\n",
    "\n",
    "### Fully-connected layers\n",
    "Classifiers have 2D-shaped weights tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Color normalization - we want all parameters to share the same color ranges in the kernel plots,\n",
    "# so we need to find the min and max across all weight tensors in the model.\n",
    "# As a last step, we also center the colorbar so that 0 is white - this makes it easier to see the sparsity.\n",
    "extreme_vals = [list((p.max(), p.min())) for param_name, p in resnet20_layer_reg.state_dict().items()  \n",
    "                if (p.dim()>1) and (\"weight\" in param_name)]\n",
    "\n",
    "flat = [item for sublist in extreme_vals for item in sublist]\n",
    "center = (max(flat) + min(flat)) / 2\n",
    "sparse_max = max(flat) - center\n",
    "sparse_min = min(flat) - center\n",
    "\n",
    "params_names = fc_param_names(resnet20_layer_reg)\n",
    "\n",
    "def view_weights(pname):\n",
    "    weights = resnet20_layer_reg.state_dict()[pname]\n",
    "    aspect_ratio = weights.size(0) / weights.size(1)\n",
    "    size = 100\n",
    "    plot_params2d([weights], figsize=(int(size*aspect_ratio),size), binary_mask=False, gmin=sparse_min, gmax=sparse_max);\n",
    "    \n",
    "\n",
    "params_dropdown = widgets.Dropdown(description='Weights:', options=params_names, layout=Layout(width='40%'))\n",
    "interact(view_weights, pname=params_dropdown);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convolution layers\n",
    "\n",
    "- Choose between the sparse and dense models.\n",
    "- Colors indicate intensity.\n",
    "- You can apply a binary mask to distinguish between zero and non-zero elements (zeros are white, everything else is black)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import math \n",
    "params_names = conv_param_names(resnet20_layer_reg)\n",
    "\n",
    "def view_weights(pname, model_choice, apply_mask, color_normalization):\n",
    "    if model_choice == 'Sparse':\n",
    "        weights = resnet20_layer_reg.state_dict()[pname]\n",
    "    else:\n",
    "        weights = resnet20_dense.state_dict()[pname]\n",
    "    \n",
    "    num_kernels = weights.size(0) * weights.size(1)\n",
    "    size = int(min(num_kernels/8, 8))\n",
    "    \n",
    "    plot_param_kernels(weights, layout=(size,8), size_ctrl=2, \n",
    "                       binary_mask=apply_mask, color_normalization=color_normalization,\n",
    "                       gmin=sparse_min, gmax=sparse_max);\n",
    "\n",
    "\n",
    "model_radio = widgets.RadioButtons(options=['Sparse', 'Dense'], value='Sparse', description='Model:')\n",
    "normalize_radio = widgets.RadioButtons(options=['Group', 'Tensor', 'Model'], value='Model', description='Normalize:')\n",
    "params_dropdown = widgets.Dropdown(description='Weights:', options=params_names, layout=Layout(width='40%'))\n",
    "mask_choice = widgets.Checkbox(value=False, description='Binary mask')\n",
    "\n",
    "interact(view_weights, pname=params_dropdown, \n",
    "         model_choice=model_radio, apply_mask=mask_choice,\n",
    "         color_normalization=normalize_radio);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize channel-wise sparsity\n",
    "\n",
    "Let's now look at the channel-wise regularized model.  Using the binary mask makes it clear which channels have been sparsified.\n",
    "\n",
    "Look for example at tensor ```module.layer1.1.conv2.weight``` which has 7 sparse channels.\n",
    "\n",
    "Then look at ```module.layer1.1.conv2.weight``` which precedes it.  If we apply the binary mask, this layer appears to be 100% dense, but once we remove the mask, we can see that this layer has 7 filters which have very low element values.\n",
    "\n",
    "To view the channel-wise sparsity of 4D weights tensors, we flatten the 4D tensors.<br>\n",
    "Weight tensors in Pytorch have a OIHW layout, which means that:<br>\n",
    "Dimension 3 - the number of output feature maps (OFMs), which is the number of filters (each filter \"generates\" a single OFM).<br>\n",
    "Dimension 2 - the number of input feature maps (IFMs), which is also the number of channels.<br>\n",
    "Dimension 1 - the height of the kernel.<br>\n",
    "Dimension 0 - the width of the kernel.<br>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# The default font size is too small, so let's increase it\n",
    "matplotlib.rcParams.update({'font.size': 32})\n",
    "\n",
    "# Let's now look at the channel-wise regularized model\n",
    "sparse_model = resnet20_channel_reg\n",
    "\n",
    "params_names = conv_param_names(sparse_model)\n",
    "\n",
    "def view_weights(pname, unused, binary_mask):\n",
    "    weights = sparse_model.state_dict()[pname]\n",
    "    weights = weights.view(weights.size(0), -1)\n",
    "    plot_params2d([weights], figsize=(50,50), binary_mask=binary_mask, xlabel=\"Channel*k*k\", ylabel=\"Filter/OFM\", gmin=sparse_min, gmax=sparse_max);\n",
    "    shape = distiller.size2str(sparse_model.state_dict()[pname].size())\n",
    "    param_info.value = shape\n",
    "\n",
    "shape =  distiller.size2str(next (iter (sparse_model.state_dict().values())).size())\n",
    "param_info = widgets.Text(value=shape, description='shape:', disabled=True)\n",
    "\n",
    "mask_choice = widgets.Checkbox(value=True, description='Binary mask')\n",
    "params_dropdown = widgets.Dropdown(description='Weights:', options=params_names, value='module.layer1.1.conv2.weight')\n",
    "interact(view_weights, pname=params_dropdown, unused=param_info, binary_mask=mask_choice);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's reset the font size\n",
    "matplotlib.rcParams.update({'font.size': 12})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Kernel L1-norm distributions\n",
    "Can we see indications of sparse kernels?<br>\n",
    "First, let's graph distribution of the L1-norms of the weights kernels, before and after pruning:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def get_norms(weights, structure='kernels'):\n",
    "    \"\"\"Compute a histogram of the L1-norms of the kernels of a weights tensor.\n",
    "    \n",
    "    The L1-norm of a kernel is one way to quantify the \"magnitude\" of the total coeffiecients\n",
    "    making up this kernel.\n",
    "    \n",
    "    Another interesting look at filters is to compute a histogram per filter.\n",
    "    \"\"\"\n",
    "    ofms, ifms, kw, kh = weights.size()\n",
    "    if structure == 'kernels':\n",
    "        groups = weights.view(ofms * ifms, kh, kw)\n",
    "    elif structure == 'filters':\n",
    "        groups = weights\n",
    "    else:\n",
    "        raise ValueError('illegal structure')\n",
    "    \n",
    "    norms = [[], []]\n",
    "    groups = groups.view(groups.shape[0], -1)\n",
    "    group_size = groups.shape[1]\n",
    "    norms[0] = groups.norm(1, dim=1).div(group_size)\n",
    "    norms[1] = groups.norm(2, dim=1).div(group_size)\n",
    "    return norms\n",
    "\n",
    "def plot_l1_norm_hist(weights, structure):\n",
    "    norms = get_norms(weights, structure)\n",
    "    if structure == 'kernels':\n",
    "        bins = 200\n",
    "    else:\n",
    "        bins = 200\n",
    "    bins = None\n",
    "    n, bins, patches = plt.hist(norms[0], bins=bins, alpha=0.5)\n",
    "    plt.title('{} L1-norm histograms'.format(structure))\n",
    "    plt.ylabel('Frequency')\n",
    "    plt.xlabel('{} L1-norm'.format(structure))\n",
    "    plt.show()\n",
    "    \n",
    "def plot_kernels_norm_hist_per_filter(ax, filter, structure, norm_mag):\n",
    "    norms = get_norms(filter, \"kernels\")\n",
    "    bins = None\n",
    "    n, bins, patches = ax.hist(norms[0], alpha=0.5)\n",
    "    ax.set_xlabel('{:2f} L1-norm'.format(norm_mag))\n",
    "    \n",
    "params_names = conv_param_names(resnet20_dense)\n",
    "\n",
    "def view_kernel_l1(pname, sort_kernels):\n",
    "    tensor = resnet20_dense.state_dict()[pname].to('cpu')\n",
    "    plot_l1_norm_hist(tensor, 'kernels')\n",
    "    nrows = (tensor.shape[0]+3)//4; ncols = 4\n",
    "    f, axarr = plt.subplots(nrows, ncols, figsize=(15,7))\n",
    "    filter_norms = []\n",
    "    for i in range(0, nrows * ncols):\n",
    "        filter = tensor[i].unsqueeze(0)\n",
    "        norm = get_norms(filter, \"filters\")[0][0].item()\n",
    "        filter_norms.append((norm, i))\n",
    "    filter_norms.sort(key=lambda norm: norm[0])\n",
    "    for i in range(0, nrows * ncols):\n",
    "        filter = tensor[filter_norms[i][1]].unsqueeze(0)\n",
    "        norm_mag = filter_norms[i][0]\n",
    "        plot_kernels_norm_hist_per_filter(axarr[i//ncols, i%ncols], filter, 'kernels', norm_mag)\n",
    "    f.subplots_adjust(hspace=0.6, wspace=0.4)\n",
    "    plt.show()\n",
    "\n",
    "sort_choice = widgets.Checkbox(value=True, description='Sort kernels')\n",
    "params_dropdown = widgets.Dropdown(description='Weights:', options=params_names, layout=Layout(width='40%'))\n",
    "interact(view_kernel_l1, pname=params_dropdown, sort_kernels=sort_choice);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare Filter L$_1$ and L$_2$ norms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params_names = conv_param_names(resnet20_dense)\n",
    "\n",
    "def view_weights(pname, sort, draw_l1, draw_l2):\n",
    "    param = resnet20_dense.state_dict()[pname]\n",
    "    view_filters = param.view(param.size(0), -1)\n",
    "    filter_size = view_filters.shape[1]\n",
    "    filter_mags_l1 = to_np(view_filters.norm(p=1, dim=1).div(filter_size))\n",
    "    filter_mags_l2 = to_np(view_filters.norm(p=2, dim=1).div(filter_size))\n",
    "    std_l2, mean_l2 = np.std(filter_mags_l2), np.mean(filter_mags_l2)\n",
    "    std_l1, mean_l1 = np.std(filter_mags_l1), np.mean(filter_mags_l1)\n",
    "    if sort:\n",
    "        filter_mags_l1 = np.sort(filter_mags_l1)\n",
    "        filter_mags_l2 = np.sort(filter_mags_l2)\n",
    "    plt.figure(figsize=[15,7.5])\n",
    "    if draw_l1:\n",
    "        plt.plot(range(len(filter_mags_l1)), filter_mags_l1, label=\"L1\", marker=\"o\", markersize=5, markerfacecolor=\"C1\")\n",
    "    if draw_l2:\n",
    "        plt.plot(range(len(filter_mags_l2)), filter_mags_l2, label=\"L2\", marker=\"+\", markersize=5, markerfacecolor=\"C1\")\n",
    "    plt.title(\"L1 mean: {:.4f}  std: {:.4f}\\nL2 mean: {:.4f}  std: {:.4f}\".format(mean_l1, std_l1, mean_l2, std_l2))\n",
    "    plt.xlabel('Filter index (i.e. output feature-map channel)')\n",
    "    plt.ylabel('Normalized Fliter L1-norm')\n",
    "    plt.legend()\n",
    "\n",
    "sort_choice = widgets.Checkbox(value=True, description='Sort filters')\n",
    "l1_choice = widgets.Checkbox(value=True, description='Draw L1')\n",
    "l2_choice = widgets.Checkbox(value=True, description='Draw L2')\n",
    "params_dropdown = widgets.Dropdown(description='Weights:', options=params_names, layout=Layout(width='40%'))\n",
    "interact(view_weights, pname=params_dropdown, sort=sort_choice, draw_l1=l1_choice, draw_l2=l2_choice);\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performance metrics\n",
    "\n",
    "Let's look at the compute performance metric of our channel-regularized model.  \n",
    "\n",
    "To measure the MACs, we should look at the model after we have actually pruned the channels, because during this pruning process (which we call \"model thinning\"), we also remove the filters that produced the pruned channels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resnet20_channel_pruned = create_model(False, 'cifar10', 'resnet20_cifar')\n",
    "checkpoint_file = \"../examples/ssl/checkpoints/checkpoint_trained_channel_regularized_resnet20_finetuned.pth.tar\"\n",
    "load_checkpoint(resnet20_channel_pruned, checkpoint_file);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "batch_size = 1\n",
    "dummy_input = Variable(torch.randn(1, 3, 32, 32), requires_grad=False)\n",
    "df_channel_pruned = distiller.model_performance_summary(resnet20_channel_pruned, dummy_input, batch_size)\n",
    "# df_sparsity = distiller.weights_sparsity_summary(sparse_model)\n",
    "#df['sparse MACs'] = df['MACs'] * (1 - df_sparsity['Ch (%)'] / 100)\n",
    "display(df_channel_pruned)\n",
    "resnet20_channel_pruned_MACs = df_channel_pruned['MACs'].sum()\n",
    "print(\"Dense MACs: \" + \"{:,}\".format(int(resnet20_channel_pruned_MACs)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare sparse model compute vs. dense model compute"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_dense = distiller.model_performance_summary(resnet20_dense, dummy_input, batch_size)\n",
    "resnet20_dense_MACs = df_dense['MACs'].sum()\n",
    "\n",
    "# Model thinning - remove layers\n",
    "print(\"Removing layers with 100% sparse weight tensors:\")\n",
    "distiller.resnet_cifar_remove_layers(resnet20_layer_reg)\n",
    "\n",
    "df_layer_pruned = distiller.model_performance_summary(resnet20_layer_reg, dummy_input, 1)\n",
    "resnet20_layer_pruned_MACs = df_layer_pruned['MACs'].sum()\n",
    "\n",
    "\n",
    "print(\"Dense MACs:  %d\" % resnet20_dense_MACs)\n",
    "print(\"Channel-wise regularized MACs: %d\" % resnet20_channel_pruned_MACs)\n",
    "print(\"\\tRatio: %.2f %%\" % (resnet20_channel_pruned_MACs / resnet20_dense_MACs * 100))\n",
    "\n",
    "print(\"Layer-wise regularized MACs: %d\" % resnet20_layer_pruned_MACs)\n",
    "print(\"\\tRatio: %.2f %%\" % (resnet20_layer_pruned_MACs / resnet20_dense_MACs * 100))\n",
    "\n",
    "df_compute = pd.concat({'Dense': df_dense['MACs'], \n",
    "                        'Ch Reg': df_channel_pruned['MACs'],\n",
    "                        #'Layer Reg': df_layer_pruned['MACs']\n",
    "                       }, axis=1)\n",
    "\n",
    "ax = df_compute.plot.bar(figsize=[15,10], title=\"MACs\");\n",
    "ax.set_xticklabels(df_layer_pruned.Name, rotation=90);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare dense FM footprint vs. dense weights footprint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_dense['FM volume'] = df_dense['IFM volume'] + df_dense['OFM volume']\n",
    "df_footprint = df_dense[['FM volume', 'Weights volume']]\n",
    "ax = df_footprint.plot.bar(figsize=[15,10], title=\"Feature-maps footprint vs. Weights footprint\", stacked=True);\n",
    "ax.set_xticklabels(df_dense.Name, rotation=90);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "<div id=\"Wen-et-al-2016\"></div> **Wei Wen, Chunpeng Wu, Yandan Wang, Yiran Chen, Hai Li**. \n",
    "    [*Learning Structured Sparsity in Deep Neural Networks*](https://arxiv.org/abs/1608.03665),\n",
    "    NIPS 2016."
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Raw Cell Format",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
